import math
import numpy as np
from PIL import Image
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from .randaugment import RandAugmentMC


cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2471, 0.2435, 0.2616)
cifar100_mean = (0.5071, 0.4867, 0.4408)
cifar100_std = (0.2675, 0.2565, 0.2761)
normal_mean = (0.5, 0.5, 0.5)
normal_std = (0.5, 0.5, 0.5)


def get_cifar10_fixmatch(root,
                         label_num,
                         batch_size,
                         unlabel_batch_size
                         ):

    transform_labeled = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(size=32, padding=int(32 * 0.125), padding_mode='reflect'),
        transforms.ToTensor(),
        transforms.Normalize(mean=cifar10_mean, std=cifar10_std)
    ])
    transform_val = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=cifar10_mean, std=cifar10_std)
    ])

    base_dataset = datasets.CIFAR10(root, train=True, download=True)

    train_labeled_idxs, train_unlabeled_idxs = x_u_split(label_num, 10, batch_size, base_dataset.targets)
    train_labeled_dataset = CIFAR10SSL(root, train_labeled_idxs, train=True, transform=transform_labeled)
    train_unlabeled_dataset = CIFAR10SSL(root, train_unlabeled_idxs, train=True, transform=TransformFixMatch(mean=cifar10_mean, std=cifar10_std))
    test_dataset = datasets.CIFAR10(root, train=False, transform=transform_val, download=False)

    labeled_loader = DataLoader(train_labeled_dataset, batch_size=batch_size, shuffle=True, drop_last=True,)
    unlabeled_loader = DataLoader(train_unlabeled_dataset, batch_size=unlabel_batch_size, shuffle=True, drop_last=True,)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return labeled_loader, unlabeled_loader, test_loader


def get_cifar100_fixmatch(root, label_num, batch_size, unlabel_batch_size):
    """
    :param root: 路径地址
    :param num_labeled: 使用的标签数量
    :param batch_size: 一次的batch
    :return:
    """
    transform_labeled = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(size=32, padding=int(32 * 0.125), padding_mode='reflect'),
        transforms.ToTensor(),
        transforms.Normalize(mean=cifar100_mean, std=cifar100_std)]
    )

    transform_val = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=cifar100_mean, std=cifar100_std)]
    )

    base_dataset = datasets.CIFAR100(root, train=True, download=True)

    train_labeled_idxs, train_unlabeled_idxs = x_u_split(label_num, 100, batch_size, base_dataset.targets)
    train_labeled_dataset = CIFAR100SSL(root, train_labeled_idxs, train=True, transform=transform_labeled)
    train_unlabeled_dataset = CIFAR100SSL(root, train_unlabeled_idxs, train=True, transform=TransformFixMatch(mean=cifar100_mean, std=cifar100_std))
    test_dataset = datasets.CIFAR100(root, train=False, transform=transform_val, download=False)

    labeled_loader = DataLoader(train_labeled_dataset, batch_size=batch_size, shuffle=True, drop_last=True,)
    unlabeled_loader = DataLoader(train_unlabeled_dataset, batch_size=unlabel_batch_size, shuffle=True, drop_last=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return labeled_loader, unlabeled_loader, test_loader


def x_u_split(num_labeled, num_classes, batch_size, labels):
    # label_per_class 每个类别所含有的数目
    label_per_class = num_labeled // num_classes
    labels = np.array(labels)
    labeled_idx = []
    # unlabeled data: all data (https://github.com/kekmodel/FixMatch-pytorch/issues/10)
    unlabeled_idx = np.array(range(len(labels)))
    for i in range(num_classes):
        idx = np.where(labels == i)[0]
        idx = np.random.choice(idx, label_per_class, False)
        labeled_idx.extend(idx)
    labeled_idx = np.array(labeled_idx)
    assert len(labeled_idx) == num_labeled

    # #  是否扩充数据集
    # if expand_labels or num_labeled < batch_size:
    #     num_expand_x = math.ceil(batch_size * eval_step / num_labeled) # 这里是为了让 标注数据和 无标注数据一样
    #     print("num_expand_x", num_expand_x)
    #     labeled_idx = np.hstack([labeled_idx for _ in range(num_expand_x)])
    np.random.shuffle(labeled_idx)

    return labeled_idx, unlabeled_idx


class TransformFixMatch(object):
    def __init__(self, mean, std):
        self.weak = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(size=32, padding=int(32 * 0.125), padding_mode='reflect')]
        )

        self.strong = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(size=32, padding=int(32 * 0.125), padding_mode='reflect'),
            RandAugmentMC(n=2, m=10)]
        )

        self.normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)])

    def __call__(self, x):
        weak = self.weak(x)
        strong = self.strong(x)
        return self.normalize(weak), self.normalize(strong)


class CIFAR10SSL(datasets.CIFAR10):
    def __init__(self, root, indexs, train=True,
                 transform=None, target_transform=None,
                 download=False):
        super().__init__(root, train=train,
                         transform=transform,
                         target_transform=target_transform,
                         download=download)
        if indexs is not None:
            self.data = self.data[indexs]
            self.targets = np.array(self.targets)[indexs]

    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index]
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target


class CIFAR100SSL(datasets.CIFAR100):
    def __init__(self, root, indexs, train=True,
                 transform=None, target_transform=None,
                 download=False):
        super().__init__(root, train=train,
                         transform=transform,
                         target_transform=target_transform,
                         download=download)
        if indexs is not None:
            self.data = self.data[indexs]
            self.targets = np.array(self.targets)[indexs]

    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index]
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target


# DATASET_GETTERS = {'cifar10': get_cifar10,
#                    'cifar100': get_cifar100}

if __name__ == '__main__':
    num_labeled = 4000
    num_classes = 10
    expand_labels = True
    batch_size = 32
    eval_step = 1024

    train_labeled_dataset, train_unlabeled_dataset, test_dataset = get_cifar10_fixmatch(num_labeled,
                                                                                        batch_size,
                                                                                        r"./data",)

    image, labels = iter(train_labeled_dataset).next()
    print(image.shape)
    print(labels)
    (image1, image2), labels = iter(train_unlabeled_dataset).next()
    print(image1.shape)
    print(labels)

    print(len(train_labeled_dataset) * batch_size)
    print(len(train_unlabeled_dataset) * batch_size * 7)
    print(len(test_dataset))
